@cpersson-amd

This PR implements the following:

  • TransformerEngine flash attention for WAN training and inference.
  • A new FSDP sharding parallelism scheme optimized for GPUs.
  • Minor changes to allow training with Flax 0.11.2.

The code has been tested on GPUs with WAN 2.1 (training and inference) and Flux (training only).

@google-cla

google-cla bot commented Dec 16, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@cpersson-amd cpersson-amd marked this pull request as draft December 17, 2025 00:18
@cpersson-amd cpersson-amd marked this pull request as ready for review December 17, 2025 10:21
@cpersson-amd cpersson-amd reopened this Dec 17, 2025
@entrpn
Collaborator

entrpn commented Dec 30, 2025

@cpersson-amd I've been out on PTO for a month. I'll take a closer look at this next week. Meanwhile, can you update your branch with the latest from main? Thanks.

Collaborator

@entrpn entrpn left a comment

In general the PR looks good, but I'm still unsure whether adding another axis, fsdp_batch, is really necessary. I would prefer not to add it. The other major change is switching the mesh_axes order from data, fsdp, tensor to data, tensor, fsdp.
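
One way to see what the axis reordering changes is to look at which device ids end up grouped along fsdp. The sketch below is illustrative only: it assumes 8 devices, hypothetical axis sizes (data=1, fsdp=4, tensor=2), and a plain row-major layout of device ids over the mesh shape, which real mesh-construction helpers may not follow. `axis_groups` is a made-up helper, not part of this PR.

```python
from itertools import product


def axis_groups(axis_names, axis_sizes, axis):
    """Group device ids that differ only in their coordinate along `axis`.

    Assumes device ids are assigned in row-major order over the mesh shape
    (an assumption for illustration, not necessarily what a real mesh does).
    """
    coords = list(product(*(range(n) for n in axis_sizes)))
    k = axis_names.index(axis)
    groups = {}
    for device_id, coord in enumerate(coords):
        # Drop the coordinate along `axis`; devices sharing the rest form a group.
        key = coord[:k] + coord[k + 1:]
        groups.setdefault(key, []).append(device_id)
    return list(groups.values())


# Hypothetical axis sizes for an 8-device host.
sizes = {"data": 1, "fsdp": 4, "tensor": 2}
old_order = ("data", "fsdp", "tensor")
new_order = ("data", "tensor", "fsdp")

old_groups = axis_groups(old_order, [sizes[a] for a in old_order], "fsdp")
new_groups = axis_groups(new_order, [sizes[a] for a in new_order], "fsdp")
```

Under these assumptions, placing fsdp last groups consecutive device ids along fsdp ([0, 1, 2, 3] and [4, 5, 6, 7]), whereas the old order strides across them ([0, 2, 4, 6] and [1, 3, 5, 7]).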

@entrpn entrpn requested a review from susanbao January 15, 2026 00:21
@entrpn
Collaborator

entrpn commented Jan 15, 2026

@susanbao can you take a quick look at this PR?

@entrpn
Collaborator

entrpn commented Jan 16, 2026

@cpersson-amd please review Sanbao's comments above and rebase on main. We tested the PR internally and it looks good. Would you be willing to rename the fsdp axis to context? If not, I can make the change after this PR is merged.

@cpersson-amd
Author

@entrpn I've rebased on main, included @susanbao's requested change, and updated the mesh axis names: fsdp -> context, fsdp_batch -> fsdp. Please let me know if anything else needs to be changed.
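
Because the new name of one axis (fsdp) is the old name of another, the two renames have to be applied in a single pass rather than one after the other. A minimal sketch of that idea follows; the `AXIS_RENAMES` table mirrors the mapping stated in the comment, while `rename_axes` and the example axis tuple are hypothetical and not taken from the PR.

```python
# Mapping stated in the comment above: fsdp -> context, fsdp_batch -> fsdp.
AXIS_RENAMES = {"fsdp": "context", "fsdp_batch": "fsdp"}


def rename_axes(axis_names):
    """Rename mesh axes in one pass (hypothetical helper).

    Each name is looked up exactly once, so "fsdp_batch" -> "fsdp" is not
    subsequently rewritten to "context".
    """
    return tuple(AXIS_RENAMES.get(name, name) for name in axis_names)


# Illustrative axis tuple; the actual order used in the PR is not shown here.
old_axes = ("data", "fsdp_batch", "tensor", "fsdp")
new_axes = rename_axes(old_axes)
```

Applying the renames sequentially in the wrong order (fsdp_batch -> fsdp first, then fsdp -> context) would collapse both axes into context, which is why the single lookup matters.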

@entrpn
Collaborator

entrpn commented Jan 20, 2026

> @entrpn I've rebased on main, included @susanbao's requested change, and updated the mesh axis names: fsdp -> context, fsdp_batch -> fsdp. Please let me know if anything else needs to be changed.

Thanks @cpersson-amd, this looks great. Can you run ruff check --fix? The unit tests are failing due to formatting right now.

@cpersson-amd
Author

cpersson-amd commented Jan 20, 2026

@entrpn Sure, I ran 'ruff check --fix' and had to manually fix some bare except statements. It should be good with the latest commit.
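
For context, bare except clauses are typically reported by ruff under rule E722, and as far as I know they are not rewritten automatically, which would explain the manual step. The sketch below shows the general shape of such a fix; `parse_steps` is a hypothetical helper, not code from this PR.

```python
def parse_steps(value, default=0):
    """Hypothetical helper: parse a step count, falling back to a default."""
    try:
        return int(value)
    # A bare `except:` here would also swallow KeyboardInterrupt and
    # SystemExit. Naming the expected exceptions lets those propagate.
    except (TypeError, ValueError):
        return default
```

When the set of expected exceptions is not known, `except Exception:` is a common fallback that still satisfies the rule while leaving interpreter-exit signals alone.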

@entrpn
Collaborator

entrpn commented Jan 20, 2026

@cpersson-amd Please review my PR to fix some of the unit tests. Once they pass, this can be merged. cpersson-amd#1
